{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8cd3d2bb",
   "metadata": {},
   "source": [
    "# Lesson 3 - Continuous Batching\n",
    "\n",
    "In this lesson, we'll discuss how batching is set up for LLM inference in production: \"continuous batching\".\n",
    "\n",
    "- The key idea behind continuous batching is to constantly swap requests that have completed generation out of the batch and replace them with requests from the queue that are waiting to be processed."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4c7beff",
   "metadata": {},
   "source": [
    "### Import required packages and load the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880bf5c8",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Import all needed functions from Lesson 1 and 2\n",
    "\n",
    "import helpers\n",
    "from helpers import init_batch, generate_next_token\n",
    "from helpers import merge_batches, filter_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd758b28-27a7-4188-b6e6-a900e8d0b0ac",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6211946-5f97-4ea8-a97d-469f769df8cd",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# local path to the model files (a GPT-2 directory, going by the path)\n",
    "model_name = \"./models/gpt2\"\n",
    "\n",
    "# load the causal language model and its tokenizer from that same path\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cee43e1",
   "metadata": {},
   "source": [
    "### Add padding tokens to the model to prepare batches of prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f30c705a-e94c-4ef7-be25-3c7441c67531",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Reuse the EOS token as the PAD token (PAD Token = EOS Token = 50256)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "# pad and truncate on the left so new tokens can be appended on the right\n",
    "tokenizer.padding_side = tokenizer.truncation_side = \"left\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f624bc-aaa6-404e-8e8a-b4361af15337",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# a small set of prompts with different lengths, sent to the model together\n",
    "prompts = [\n",
    "    \"The quick brown fox jumped over the\",\n",
    "    \"The rain in Spain falls\",\n",
    "    \"What comes up must\",\n",
    "]\n",
    "\n",
    "# note: padding=True ensures the padding token will be inserted into the tokenized tensors\n",
    "inputs = tokenizer(\n",
    "    prompts,\n",
    "    padding=True,\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "498c3b6d",
   "metadata": {},
   "source": [
    "### Define needed functions for batching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f799f2fd-483e-4d41-a807-1bed7ded754f",
   "metadata": {
    "height": 675
   },
   "outputs": [],
   "source": [
    "def generate_batch_tokens_with_past(inputs):\n",
    "    \"\"\"Run one forward pass and greedily pick the next token for each row.\n",
    "\n",
    "    Returns a tuple of (next token ids, one per batch row; the model's\n",
    "    past_key_values from this pass).\n",
    "    \"\"\"\n",
    "    # inference only, so gradient tracking is switched off\n",
    "    with torch.no_grad():\n",
    "        model_out = model(**inputs)\n",
    "\n",
    "    # keep the logits at the final sequence position and take the argmax per row\n",
    "    final_step_logits = model_out.logits[:, -1, :]\n",
    "    return final_step_logits.argmax(dim=1), model_out.past_key_values\n",
    "\n",
    "\n",
    "def generate_batch(inputs, max_tokens):\n",
    "    \"\"\"Greedily decode `max_tokens` new tokens for every row of a padded batch.\n",
    "\n",
    "    Args:\n",
    "        inputs: mapping that provides at least \"input_ids\" and\n",
    "            \"attention_mask\" for the batch.\n",
    "        max_tokens: number of decode steps to run; every row gets the same count.\n",
    "\n",
    "    Returns:\n",
    "        One string per row: that row's newly generated tokens, decoded and\n",
    "        concatenated (the prompt text is not included).\n",
    "    \"\"\"\n",
    "    # create a list of tokens for every input in the batch\n",
    "    generated_tokens = [[] for _ in range(inputs[\"input_ids\"].shape[0])]\n",
    "\n",
    "    attention_mask = inputs[\"attention_mask\"]\n",
    "    # running count of attended tokens gives 0-based positions that ignore left padding\n",
    "    position_ids = attention_mask.long().cumsum(-1) - 1\n",
    "    # slots with attention_mask == 0 get a fixed placeholder position\n",
    "    position_ids.masked_fill_(attention_mask == 0, 1)\n",
    "\n",
    "    next_inputs = {\n",
    "        \"position_ids\": position_ids,\n",
    "        **inputs\n",
    "    }\n",
    "    for _ in range(max_tokens):\n",
    "        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)\n",
    "        prev_mask = next_inputs[\"attention_mask\"]\n",
    "        next_inputs = {\n",
    "            \"input_ids\": next_token_ids.reshape((-1, 1)),  # '-1' here means the remaining elements for this dim\n",
    "            \"position_ids\": next_inputs[\"position_ids\"][:, -1].unsqueeze(-1) + 1,  # increment last, discard the rest\n",
    "            \"attention_mask\": torch.cat([\n",
    "                prev_mask,\n",
    "                # new_ones keeps the mask's dtype and device; a bare torch.ones would be\n",
    "                # float32 on the default device and silently change the mask's dtype\n",
    "                prev_mask.new_ones((next_token_ids.shape[0], 1)),\n",
    "            ], dim=1),\n",
    "            \"past_key_values\": past_key_values,\n",
    "        }\n",
    "\n",
    "        # decode this step's token of every row and append it to that row's list\n",
    "        next_tokens = tokenizer.batch_decode(next_token_ids)\n",
    "        for i, token in enumerate(next_tokens):\n",
    "            generated_tokens[i].append(token)\n",
    "    return [\"\".join(tokens) for tokens in generated_tokens]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "334ec741",
   "metadata": {},
   "source": [
    "### Define the requests to be processed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d785a01-5cfa-487e-9cbe-5d3d6e5c48b0",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "# fix the seed of Python's random module so our results are deterministic\n",
    "random.seed(42)\n",
    "\n",
    "# constants: number of queued requests and number of requests per batch\n",
    "queue_size = 32\n",
    "batch_size = 8\n",
    "\n",
    "# requests waiting to be processed, as (prompt, max_tokens) tuples:\n",
    "# every `batch_size`-th request asks for 100 tokens, all others for 10\n",
    "request_queue = [\n",
    "    (prompts[0], 10 if i % batch_size else 100)\n",
    "    for i in range(queue_size)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cad0536-9d5b-43ab-8015-a43aaacbb281",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "request_queue[:8]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07543da0-26b9-4ba1-bc99-5c42045c8078",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# cut the queue into consecutive chunks of at most `batch_size` requests\n",
    "batches = [\n",
    "    request_queue[start:start + batch_size]\n",
    "    for start in range(0, len(request_queue), batch_size)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b895c985-9964-4d58-952f-050446a4ae83",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "len(batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91a61679-392b-4ff9-9e92-9deb35ef8305",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "batches[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b98bb175",
   "metadata": {},
   "source": [
    "### Processing batches \n",
    "\n",
    "**Note:** Your results might differ somewhat from those shown in the video, but they will still follow the same pattern as explained by the instructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "017cc842-95ef-4307-9996-5832644a6cc3",
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "# static batching baseline: run every fixed batch to completion and record the total duration\n",
    "t0 = time.time()\n",
    "with tqdm(total=len(batches), desc=f\"bs={batch_size}\") as pbar:\n",
    "    for batch in batches:\n",
    "        # to accommodate all the requests with our\n",
    "        # current implementation, we take the max of\n",
    "        # all the tokens to generate among the requests\n",
    "        batch_max_tokens = [b[1] for b in batch]\n",
    "        max_tokens = max(batch_max_tokens)\n",
    "        pbar.set_postfix({'max_tokens': max_tokens})\n",
    "\n",
    "        batch_prompts = [b[0] for b in batch]\n",
    "        # batch-specific name so the notebook-level `inputs` is not overwritten\n",
    "        batch_inputs = tokenizer(\n",
    "            batch_prompts, padding=True, return_tensors=\"pt\")\n",
    "        # the generated text is discarded; only the elapsed time is reported below\n",
    "        generate_batch(batch_inputs, max_tokens=max_tokens)\n",
    "\n",
    "        pbar.update(1)\n",
    "\n",
    "duration_s = time.time() - t0\n",
    "print(\"duration\", duration_s)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04949ab4",
   "metadata": {},
   "source": [
    "### Let's try continuous batching\n",
    "\n",
    "- This time, rather than processing each batch to completion, you will use continuous batching to dynamically swap finished inputs out of the batch and swap waiting inputs in from the queue.\n",
    "\n",
    "**Note:** Your results might differ somewhat from those shown in the video, but they will still follow the same pattern as explained by the instructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87a92e27-2f7a-424d-8fd7-108784ec759e",
   "metadata": {
    "height": 896
   },
   "outputs": [],
   "source": [
    "# fix the seed of Python's random module so our results are deterministic\n",
    "random.seed(42)\n",
    "\n",
    "# constants: number of queued requests and size of the running batch\n",
    "queue_size = 32\n",
    "batch_size = 8\n",
    "\n",
    "# rebuild the queue of (prompt, max_tokens) requests, since the loop below consumes it:\n",
    "# every `batch_size`-th request asks for 100 tokens, all others for 10\n",
    "request_queue = [\n",
    "    (prompts[0], 10 if i % batch_size else 100)\n",
    "    for i in range(queue_size)\n",
    "]\n",
    "\n",
    "t0 = time.time()\n",
    "with tqdm(total=len(request_queue), desc=f\"bs={batch_size}\") as pbar:\n",
    "    # initial prefill: the first `batch_size` requests\n",
    "    # become the starting cached_batch\n",
    "    batch = init_batch(request_queue[:batch_size])\n",
    "    cached_batch = generate_next_token(batch)\n",
    "    request_queue = request_queue[batch_size:]\n",
    "\n",
    "    # keep going while requests are still queued or\n",
    "    # the cached_batch still holds unfinished inputs\n",
    "    while request_queue or cached_batch[\"input_ids\"].size(0) > 0:\n",
    "        batch_capacity = batch_size - cached_batch[\"input_ids\"].size(0)\n",
    "        if request_queue and batch_capacity > 0:\n",
    "            # prefill as many waiting requests as there are free slots\n",
    "            new_batch = generate_next_token(\n",
    "                init_batch(request_queue[:batch_capacity])\n",
    "            )\n",
    "            request_queue = request_queue[batch_capacity:]\n",
    "\n",
    "            # merge the newcomers into the running batch\n",
    "            cached_batch = merge_batches(cached_batch, new_batch)\n",
    "\n",
    "        # decode: one more token for everything in the running batch\n",
    "        cached_batch = generate_next_token(cached_batch)\n",
    "\n",
    "        # drop inputs that have finished generation and\n",
    "        # advance the progress bar by that many requests\n",
    "        cached_batch, removed_indices = filter_batch(cached_batch)\n",
    "        pbar.update(len(removed_indices))\n",
    "\n",
    "duration_s = time.time() - t0\n",
    "print(\"duration\", duration_s)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
